{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Scratch plot of the first two columns of `s_` over the reference course (cx, cy).\n",
    "# NOTE(review): `s_`, `cx`, `cy` and `len_horizon` are not defined in this cell;\n",
    "# they must already exist in the kernel (cx, cy and len_horizon are set by the\n",
    "# MPC setup cell; `s_` is not defined anywhere in the visible notebook).\n",
    "x_ = s_[:,0].data.numpy()   # was `vsx_`, which left `x_` undefined for the scatter below\n",
    "y_ = s_[:,1].data.numpy()\n",
    "\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1)\n",
    "\n",
    "# Waypoint offsets 0,3,...,3*len_horizon, the same values the MPC setup cell uses.\n",
    "# The original line was indented (IndentationError) and called the non-existent\n",
    "# `torch.longTensor`.\n",
    "vs = torch.arange(len_horizon+1)*3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a vehicle_model and call forward() once on the first planned state\n",
    "# and the first control row.\n",
    "# NOTE(review): `path` and `controls_dt` are produced by the MPC loop cell further\n",
    "# down, so this cell only runs after that cell has been executed.\n",
    "# NOTE(review): what the positional argument 1. means depends on vehicle_model's\n",
    "# signature, which is not visible in this notebook - TODO confirm.\n",
    "rc =vehicle_model(1.)\n",
    "rc.forward(path[0],controls_dt[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same one-step forward() call as the `rc` cell, but with 1./200 as the\n",
    "# constructor argument.\n",
    "# NOTE(review): `path` and `controls_dt` come from the MPC loop cell further down;\n",
    "# the meaning of the positional argument is not visible here - TODO confirm.\n",
    "rc_ =vehicle_model(1./200)\n",
    "rc_.forward(path[0],controls_dt[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class corrected_model:\n",
    "    \"\"\"Nominal model plus an additive learned correction.\n",
    "\n",
    "    Wraps two callables that are both invoked as ``f(state, control)``:\n",
    "    ``p_model`` returns a nominal prediction, and ``gp_model`` returns a\n",
    "    ``(mean, sigma)`` pair whose mean is added to the nominal prediction.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,p_model,gp_model):\n",
    "        self.p_model = p_model\n",
    "        self.gp_model = gp_model\n",
    "        # The original stored the correction only under `g_model` while the\n",
    "        # method read `gp_model`; keep the old attribute so both names resolve.\n",
    "        self.g_model = gp_model\n",
    "\n",
    "    def forward(self,state,control):\n",
    "        \"\"\"Return ``(mean, sigma)`` for one step.\n",
    "\n",
    "        ``mean`` is ``gp_model``'s mean plus ``p_model(state, control)``;\n",
    "        ``sigma`` is passed through from ``gp_model`` unchanged.\n",
    "        \"\"\"\n",
    "        _myu,sigma = self.gp_model(state,control)\n",
    "        myu = _myu + self.p_model(state,control)\n",
    "        return myu,sigma\n",
    "\n",
    "    # Backward-compatible alias for the original, misspelled method name.\n",
    "    foraward = forward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# 6x2 selector: a 2x2 identity stacked between two 2x2 zero blocks.\n",
    "zero2, eye2 = torch.zeros(2, 2), torch.eye(2)\n",
    "B = torch.cat((zero2, eye2, zero2), 0)\n",
    "\n",
    "# Probe vector 0..5; B^T x keeps entries 2 and 3 (the value is not displayed).\n",
    "x = torch.arange(6).to(torch.float32)\n",
    "torch.matmul(B.t(), x)\n",
    "\n",
    "# z2s stacks a 2x4 zero block over a 4x4 identity (6x4); s2z is its transpose.\n",
    "z2s = torch.cat((torch.zeros(2, 4), torch.eye(4)), 0)\n",
    "s2z = z2s.transpose(0, 1)\n",
    "s2z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0,  3,  6,  9, 12, 15, 18, 21, 24, 27, 30])\n",
      "T=\n",
      "0\n",
      "1.482811\n",
      "0.001083556\n",
      "tensor([ 0.2060,  0.5248,  1.0871,  1.9265,  3.0772,  4.5673,  6.4167,  8.6405,\n",
      "        11.2844, 13.5018])\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "1\n",
      "56.616764\n",
      "0.01573755\n",
      "tensor([ 0.2088,  0.5334,  1.1098,  1.9745,  3.1631,  4.7061,  6.6230,  8.8777,\n",
      "        11.1446, 13.0028])\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n",
      "T=\n",
      "2\n",
      "40.859818\n",
      "0.08806099\n",
      "tensor([ 0.2124,  0.5434,  1.1323,  2.0161,  3.2335,  4.8202,  6.7807,  8.9705,\n",
      "        11.0025, 12.9813])\n",
      "torch.Size([6]) torch.Size([3])\n",
      "torch.Size([]) torch.Size([212]) torch.Size([]) torch.Size([212])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7ff65407eba8>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYh0lEQVR4nO3dfZBddX3H8fdXksiG8NC6q7YkYRlLfRirQlZHy4x1RCxVDFr8Q1oduWaMHYtFhd0SaWGwI9rdjshUtM0IF6sQWlFH6lgFbRnHmSpsKD4RqYwGCIXmbh+ACbFJ2m//OPeyd2/u8z17f7/zO5/XTGb3Puw5v9x77vd87/f3cMzdERGR9DwjdANERGR1KMCLiCRKAV5EJFEK8CIiiVKAFxFJ1JoQO52cnPTp6ekQuxYRKazdu3cvuftUv88PEuCnp6dZXFwMsWsRkcIyswcHeb5KNCIiiVKAFxFJlAK8iEiiFOBFRBKlAC8ikigFeBGRRCnAi4gkSgFeRMZnaQkWFrKfjd/vv3/5PslVkIlOIlIyS0tQrcKBA3DVVcv3z83BnXfC176WPXbccVCpwORksKamRAFeRFZPa2C/8kqYn8+CeMPWrfCa12TPmZtToM+RAryI5K9bYG8O2rOzyz+XlrLA3gj0kD2/WlWwH5ICvIjkr1rNgnSnwN7O5OTKQN8I7srqh6YALyL5aGTtlcpyCWaYgNwI9I2/h5VZfeMx6UkBXkRG064DdXY2n0DcmtVv3ZqNuFEm3xcNkxSR0TTKKHB0B2peGoH+ttuyfVWr+e8jQcrgRWQ0o5Rjht2XMvm+KIMXkeE0JipBll2PI9Aqkx9ILgHezE4ys1vN7CdmtsfMXpXHdkUkYo3STIggW6ksl4OaZ8fKCnmVaK4Fvu7ubzWzdcD6nLYrIrFpdKpu3ZrdXo2aey/NI20WFjTCpoORA7yZnQi8GrgQwN0PAYdG3a6IRKq5UzWGgKq6fEd5ZPCnAjWgamYvBXYDF7v7geYnmdl2YDvA5s2bc9itiATR3Kkag0Y2r0z+KHnU4NcAZwCfdvfTgQPAZa1Pcved7j7j7jNTU1M57FZExipEp+ogmuvyAuQT4PcB+9z9e/Xbt5IFfBFJSchO1X601uXV6Tp6icbdHzOzh83s+e5+P3AWcN/oTRORqMRWmukktj6CgPIaRfM+4Kb6CJqfAZEfASLSt+Y1ZooQMItyIhqDXMbBu/u99fr6S9z9ze7+X3lsV0QiEHtpppVKNU/TUgUi0l1RM2KVahTgRaSN5rJMc0ZcJEU9MeVIa9GIyNGKVpZpR6UaZfAi0kZK2W+JSzUK8CKyrGgjZvqR0slqQCrRiMiyFEozrUpcqlEGLyLLUs52S1iqUYAXkTRLM61SPnl1oBJNUTVf5GDQ30VapViaadUo1UxOlubzoAy+SJqzrOavmzDY742r1DePc5ZyK1t2W5JyjQJ87DoF9XYfyEF+b92Wgn25FXUy07DKckJz97H/27Jli0sXtZr7/PzyT1h5u1ZbvX20PiZp03tdKMCiDxBrlcHHqFOmnmeW1byt1mym9euryjnpKkmpoqPEj20F+Fg0H2irFdQ7ad1Ht4Cvck5aylKq6CTxE5wCfGiNwH7gAFx1VXbf7GzYg61bwE/8A1EKKSwklpfET3AaJhlac8CM9XqSzcPLWq97WZLhZkkpw5DIfiU+y1UZfAjdyjGxa834lNEXT+JZ61ASPY4V4ENoPZiKfEC1BovEO60KrQyzVYeV6ElPAX6cGh+wrVuz2ykcTMroi0PvTWeJ9kUowI9TGT5gyujjlWiWmpsEj1UF+NXWqd6eKmX08Uo0S81NgseqAvxqS6nePgxl9OHpNe9PggmYAvxqSbHePgxl9OHpNe9Pgt9wcgvwZnYMsAg84u7n5rXdwtKHqj1l9OOXYGa6qhI6JvPM4C8G9gAn5LjN4tKHqj1l9OOh2arD
S+iYzCXAm9lG4I3AR4AP5rHNwtJY48G0OxEmlEEFk1CQGruEkrO8MvhPAHPA8Z2eYGbbge0Amzdvzmm3EdIHazDtsku9hqNLKEiNXULfeEZei8bMzgX2u/vubs9z953uPuPuM1NTU6PuNl6ta7XI4LTezfAarxUsrx8kw0nguMtjsbEzga1mthe4BXitmX0+h+0WR/OB0Lwwlwyn9TXU4lj902uVnwRey5FLNO6+A9gBYGavAS5197ePut1CUUlhdWnkTf9UmslPAq+lxsGPQmPdx0Mjb7rTiJnVkcBrmWuAd/c7gTvz3GbUFGjCUEa/ko5D6UAZ/CgS+ApXSMroV9JxuLoKnEDoik7D0EiFuJRx1I069senwJ2tyuCHUfaMMTZlzOjL8H+MRYG/ISnAD6PAb3gppFqjL9vS07EocGerSjSDUGmmGFIdR9/8/1BZZvwKWPpTBj8IfS0upiJn9Mra41HAz78C/CD0ASumItfoy37BmJgU8PNv7j72nc7MzPji4uLY9ysCHJ3Bx5bRN7cH4mqbBGVmu919pt/nqwbfjwLW3qSLXjX60O+3au2SEwX4fqTSSSfttY6jDxHwm/ehFUnjFfrkPyDV4PtRwNqbDKC1Rt/6frfWwfMq6TRvR7X2YihS/w0K8N3p6kzlNGrA71ZD7xTUlUQUQ8HeJwX4bgp2tpZVMmjAb74NnR9r3k6BJ9OUSsHeJwX4bgp2tpYx6RXw2x037R4rWLCQJrGNvOpAwyRFRAa1sJB9E5ufH+tJetBhksrgOynIGVpEAijIt3sNk+xEQyNFpJOCzE9QBt9JQc7QIiKdKINvpRUjRaQfBZj0pADfSqUZEelHAWKFSjStVJoRkX4UIFZomKSISEFoNclhFaCeJiIyCAX4hgLU00QkQhEnhyPX4M1sE/A3wHMAB3a6+7WjbnfsClBPE5EIRbxmVR6drEeAS9z9HjM7HthtZne4+305bHt8tC6IiAwj4uRw5BKNuz/q7vfUf38S2AOcPOp2xybir1ciUgARz2rNtQZvZtPA6cD32jy23cwWzWyxVqvludvRqPYuIqOKNFHMbRy8mW0Avgi8392faH3c3XcCOyEbJpnXfkcW8dcrESmISOvwuQR4M1tLFtxvcvcv5bHNsVHtXURGFWmiOHKJxswMuB7Y4+4fH71JYxTp1yoRKZhI6/B51ODPBN4BvNbM7q3/e0MO2119qr+LSMJGLtG4+3cAy6Et4xfp1yoRKajILhRU7pmskX6tEpGCiqwqUM7VJCM7y4pIIiKrCpQzg4/sLCsiiYisKlDODD6ys6yIyGooZwYf2VlWRBIS0fDrcgX4iF54EUlURCXgcpVoIp1OLCIJiagEXK4AH9ELLyKJimj5k3IF+IheeBGR1VaeGrzq7yIyLpHEm/IE+Ig6PkQkcZHEm/KUaFR/F5FxiSTemPv4r70xMzPji4uLY9+viEiRmdlud5/p9/nlKdGIiJRM+gE+ks4OESmZCGJP+gE+ks4OESmZCGJP+p2skXR2iEjJRBB71MkqIlIQ6mRtFkENTEQklLQDfAQ1MBEpscBJZto1+AhqYCJSYoFXsE07wGtxMREJKXCSmXaAFxEJKXCSmWYNXp2rIhKLgPEolwBvZueY2f1m9oCZXZbHNkeizlURiUXAeDRyicbMjgGuA84G9gF3m9lt7n7fqNsemjpXRSQWAeNRHhn8K4AH3P1n7n4IuAU4L4ftDq9R95qc7P48lXJEZLX1G49WQR4B/mTg4abb++r3rWBm281s0cwWa7VaDrvtYJCg3frVSQFfRBIytk5Wd9/p7jPuPjM1NbV6Oxqk3lWpwPz88len5r9VsBeRPASMJXkMk3wE2NR0e2P9vjAGqXe1DmFq/tvWCQpLS9l9lUqQr1oiUlABJzvlEeDvBk4zs1PJAvvbgN/LYbvDGWXcafPftp4oAs9IE5GCCtjJmstqkmb2BuATwDHADe7+kW7PL+Rqks0ZPCibF5GxC7Ka
pLt/zd1/3d2f1yu4r5rVrnM194Src1ZECiCdmazjnEzQrXMWFPBFZKVAMSGdtWjGWefq1jkLK+v1jQ5blXNEyitQH146AT7koj6DjMYRkfIJ1NGaxiX7Yh7C2Nq2mNsqIlEr5yX7Yl5crHWasiZTiciYpFGiKdLiYppMJVI+gT7baQT4Il25SZOpRMpHnawl1K1zVtm8SDoCVRmKX4NPqY6tyVQiaQq0ZHDxM/hUyxq9yjfK8EWkh+IH+CJ1sA5Ck6lE0hIgKSt+gC9SB+soNJlKpNgCfE6LG+DLXqLoNhqn7K+NSIwCVBuK28ka8+Smces2mQrUQSsSgwAdrcXN4FOtvedB4+tF4qMa/ADKUnsfhsbXi8RHNXhZFc0Bf2FBwy1FQlANvk+qKQ9PFysRCUM1+D6ppjy8QcbX67UVKbRiBnh1sOanV8BXCUckHwE+S8Us0QRa16EUNORSZHUEGNpdzAxexkclHJF8BKg8FOuSfSoXhNf8HoDeD5ExGusl+8xswcx+YmY/MLMvm9lJo2yvJ81eDU9LGosMJ8DnY9QSzR3ADnc/YmZ/DuwA/nj0ZnWgztW4qHwj0r+iTXRy99ubbn4XeOtozelBs1fjohE4Iv0LkKDm2cn6LuBvOz1oZtuB7QCbN2/OcbcSjdaAr4xeZFmABLVngDezbwLPbfPQ5e7+lfpzLgeOADd12o677wR2QtbJOlRrpViU0YsE1TPAu/vruj1uZhcC5wJneYghORKvXhm9Ar6USdFWkzSzc4A54Lfc/al8miTJUqeslFnROlmBTwLPBO4wM4DvuvsfjNwqSZM6ZaXMitbJ6u6/lldDpITUKStlEqCTtZhr0UiaWpcyBk2eknQEOJYV4CUe7RaR02xZSYUWGxNpoY5ZSYUWGxPpobUjVh2zUiJjXWxMZOx6rVcvEqsCLjYmEpaGWkpRFHAcvEhYGmopRVG0cfAi0Wn3IVJWLzHQOHiREfUz1FJk3AIN71UGL+lTnV5CC1Q6VICX9KlOL6EFuhqdAryUjzJ6GbdAV6NTDV7KR2PpZdxUgxcJRCNvZLWpBi8SSLuvz6rTS55UgxeJiOr0kifV4EUiojq95CXgEtfK4EX6oTq9DCtguU8BXqQfqtPLEGo12PeyCqddARvGXH8HBXiR4Smrly527YJt22DdukkOHZrl+hfABReMtw2qwYsMS+veSAe1Whbc1x9c4t2PL7D+4BLbtmX3j5MyeJE8afSNAHv3wrp1UDlYZYGsjLdz7Sx798LU1PjaoQAvkieteyPA9DQcOgRVshN9lQqHD2f3j5MCvMhqUp2+lKam4HPXLLH7oiq7jq3w1P9Ocv31483eIacavJldYmZuZjpaRZqpTl9a5z9R5eojc3y7UuXBB8ffwQo5ZPBmtgl4PfDQ6M0RKQFl9eVQf383VSoQ6C3NI4O/BpgDPIdtiaRPWX36Ijlhj5TBm9l5wCPu/n0z6/Xc7cB2gM2bN4+yW5H0aPRNWiLpXO8Z4M3sm8Bz2zx0OfAhsvJMT+6+E9gJMDMzo2xfpJlG3yQj9OzVZj0DvLu/rt39ZvYbwKlAI3vfCNxjZq9w98dybaVI2ahOX0hf/Ots5MwtExUeOxJm9mqzoWvw7v5Dd3+2u0+7+zSwDzhDwV0kB6rTF06tBrsvykbOnP9klYMHCTJ7tZnGwYsUhbL6qO3dC7dMVDj85PIEp7VrGfvs1Wa5Bfh6Fi8iq0UrWkbt1OOXeNvBKjup8B/1cZEhZq82UwYvUmQafRONyb/PyjOsgevWz3L4MEFmrzZTgBcpMo2+iUf9JPvBN1V4y5NZ5h4yuIMCvEhaVKcfv+bXd3aWSYJNXD2K1oMXSUm/o28CXic0JbUaPPzheEc3KYMXSV27rF6lnJE1xrx/45lb+e01sOWECueHblQLBXiR1LUbfdNH52ytlg3xi6GWHJvmMe+Hj8BHmWXiA/Dq343rtVKAFymjHp2zy9cTzS5ccf31YWdkxibGMe/t
qAYvIlnmPj8PlUo01xON1tISL/zqAocPw18wG82Y93YU4EVkRefs09cTJbueaIUqa9fCvnvL1zFbq8Hdd7ec3KpVNnx4jtveUmViAk44ASYmwo95b0clGhFZodP1RE/7ThU+3NIx21K7T6lu31qm+tw1S5z/RBW2bgXg9EqFB6+N+/+rAC8iK0xNZdnotm2T7Fy7PCNzw9kV2EDH0Ti7Ns4mU7dvlKkOHqyXqqiy56IDcOSq7An1E9wUcQb2BnMf/9LsMzMzvri4OPb9ikj/+srG6xn80psqbD5jkvUHl6hQpUqFpyYmefDBuANgp//j3XfD2WfD44/DpSywwBwfXXclb3/PcWy6ItyEMTPb7e4z/T5fNXgRaWtqCl7+8h4Bul67//mTk0PX7dvWuXPQa7u7dsEpp2SB/JRTstsNpx6/xHsPLPAslqhSYZZ5PvWMizj2T2eLNRvY3cf+b8uWLS4i6di/331iwv1Z1PxS5v1Z1Hxiwv3JK+bdwX1+PntirZb9Xqu5u/vNN2d/d+KJ2c+bb+69n7vuyn5202u7jfbC8r+Nx9ay9jbaCL5jzbyfcEJ/bRsHYNEHiLUK8CKSi0ZQXREQWwJ6I3D6/Lzv358F1cYJAbK/6xS8+z0ZtAverdu9665sO7B8Urp63ZXLJ6N6u2t7an2dUMZFAV5EgumZYTcF/Lvucv+TY7OAfynzDtnJ4Z7bm04KTYG2V9BuaA7ejX/TG2r+0PuWTzS1PTXfsWb+6eDu4H92zJXLGXykBg3wGkUjIrmZmuqvZg8w7XAjFX7B8pDMo4ZjAszNcfBhWLdudkUn7uG1k+y7d4mpe5uWWFha4oVfrXL8/1RYA08/94JfVNn0l3OwiWzFx6a123cdW2HtL2DLJytseE+B6ut9UIAXkSCmpmD+hkm2bZtl7VqY6DIcc+JNFQ59Bt5d78QFuO7w7NFj8xuTkC6AL3wBrj4yx9o1WfDmCY5agydbu32S6enZqEf7DEvDJEUkqH4nR+3aBXPvWuJCr3KjVZi/YZILzm5ZJK1p4tXSEhz8VJWJ91aYfEEamfmgwyQV4EWkMFKaKTuMQQO8SjQiUhg9a/yygiY6iYgkSgFeRCRRCvAiIolSgBcRSZQCvIhIohTgRUQSFWQcvJnVgAfHvuP2JoGiXINMbc1fUdoJautqKEo7IWvrce7e90DRIAE+Jma2OMjEgZDU1vwVpZ2gtq6GorQThmurSjQiIolSgBcRSZQCPOwM3YABqK35K0o7QW1dDUVpJwzR1tLX4EVEUqUMXkQkUQrwIiKJKnWAN7NzzOx+M3vAzC4L3Z5OzGyTmf2Tmd1nZj82s4tDt6kbMzvGzP7FzL4aui3dmNlJZnarmf3EzPaY2atCt6kTM/tA/b3/kZntMrNjQ7epwcxuMLP9Zvajpvt+2czuMLOf1n/+Usg21tvUrp0L9ff/B2b2ZTM7KWQbG9q1temxS8zMzaznVUxKG+DN7BjgOuB3gBcBF5jZi8K2qqMjwCXu/iLglcAfRtxWgIuBPaEb0Ydrga+7+wuAlxJpm83sZOCPgBl3fzFwDPC2sK1a4UbgnJb7LgO+5e6nAd+q3w7tRo5u5x3Ai939JcC/AjvG3agObuTotmJmm4DXAw/1s5HSBnjgFcAD7v4zdz8E3AKcF7hNbbn7o+5+T/33J8kC0clhW9WemW0E3gh8JnRbujGzE4FXA9cDuPshd//vsK3qag0wYWZrgPXAvwVuz9Pc/dvAf7bcfR7w2frvnwXePNZGtdGune5+u7sfqd/8LrBx7A1ro8NrCnANMAf0NTqmzAH+ZODhptv7iDRoNjOzaeB04HthW9LRJ8gOwP8L3ZAeTgVqQLVeTvqMmR0XulHtuPsjwF+QZW2PAo+7++1hW9XTc9z90frvjwHPCdmYPr0L+IfQjejEzM4DHnH37/f7N2UO8IVjZhuALwLvd/cnQrenlZmdC+x3992h29KHNcAZwKfd/XTgAHGUEY5Sr1+f
R3ZS+lXgODN7e9hW9c+zsdhRj8c2s8vJSqE3hW5LO2a2HvgQcMUgf1fmAP8IsKnp9sb6fVEys7Vkwf0md/9S6PZ0cCaw1cz2kpW8Xmtmnw/bpI72AfvcvfFN6FaygB+j1wE/d/eaux8GvgT8ZuA29fLvZvYrAPWf+wO3pyMzuxA4F/h9j3di0PPITvDfr3++NgL3mNlzu/1RmQP83cBpZnaqma0j67S6LXCb2jIzI6sV73H3j4duTyfuvsPdN7r7NNnr+Y/uHmWm6e6PAQ+b2fPrd50F3BewSd08BLzSzNbXj4WziLRDuMltwDvrv78T+ErAtnRkZueQlRS3uvtTodvTibv/0N2f7e7T9c/XPuCM+nHcUWkDfL1j5SLgG2Qflr9z9x+HbVVHZwLvIMuI763/e0PoRiXgfcBNZvYD4GXA1YHb01b9W8atwD3AD8k+t9FMsTezXcA/A883s31mtg34GHC2mf2U7BvIx0K2ETq285PA8cAd9c/VXwVtZF2Htg6+nXi/kYiIyChKm8GLiKROAV5EJFEK8CIiiVKAFxFJlAK8iEiiFOBFRBKlAC8ikqj/BzQY43wyyQHBAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from evals import *\n",
    "from optimization import *\n",
    "from gauss_update import *\n",
    "from kinetic_model import *\n",
    "import cubic_spline_planner\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Control points of the reference course; `goal` is the last control point\n",
    "# (it is not referenced again in the visible cells).\n",
    "ax = [0.0, 6.0, 12.5, 10.0, 7.5, 3.0, -1.0]\n",
    "ay = [0.0, -3.0, -5.0, 6.5, 3.0, 5.0, -2.0]\n",
    "goal = [ax[-1], ay[-1]]\n",
    "\n",
    "# Spline course through the control points, requested with ds=0.2.\n",
    "cx, cy, cyaw, ck, s = cubic_spline_planner.calc_spline_course(\n",
    "        ax, ay, ds=0.2)\n",
    "\n",
    "# Course coordinates as (N,1) column tensors.\n",
    "cx = torch.Tensor(cx).view(-1,1)\n",
    "cy = torch.Tensor(cy).view(-1,1)\n",
    "\n",
    "# Forward differences between consecutive course points, shape (N-1,1).\n",
    "gx = cx[1:]-cx[:-1]\n",
    "gy = cy[1:]-cy[:-1]\n",
    "\n",
    "\n",
    "# Tunables. EMAX is read as a global by GP_MPC.run (optimisation epochs), TMAX is\n",
    "# the number of outer MPC steps, `step` sets how far the warm start is shifted, and\n",
    "# sigma_w is handed to gaussian_propagator below. ql and qc are read as globals by\n",
    "# evaluate_. kappa is not referenced in the visible cells.\n",
    "len_horizon = 10\n",
    "EMAX=1000\n",
    "kappa=0.01\n",
    "TMAX = 3\n",
    "ql=1.\n",
    "qc=1.\n",
    "step = 0\n",
    "sigma_w=0.01\n",
    "\n",
    "# 6x2 matrix: zeros / identity / zeros blocks (not referenced again in this cell).\n",
    "B = torch.cat([torch.zeros(2,2),torch.eye(2),torch.zeros(2,2)],dim=0)\n",
    "\n",
    "# Planning model, GP model and the propagator that combines their forward() methods.\n",
    "# NOTE(review): vehicle_model, GP_test and gaussian_propagator presumably come from\n",
    "# the star imports above; which module provides each is not visible here.\n",
    "p_model = vehicle_model(WB=2.5)\n",
    "gp_model = GP_test(dim=2)\n",
    "gp_propagator = gaussian_propagator(p_model.forward,gp_model.forward,sigma_w)\n",
    "\n",
    "# Model used to simulate the car in the loop; note WB differs from p_model's and\n",
    "# noise is enabled (presumably a deliberate model mismatch - TODO confirm).\n",
    "real_car_model = vehicle_model(WB=1.5,noise=True)\n",
    "\n",
    "# One row per course point [x, y, gx, gy]; the last course point is dropped so the\n",
    "# coordinates line up with the N-1 differences.\n",
    "waypoints = torch.cat([cx[:-1],cy[:-1],gx,gy],dim=1)\n",
    "\n",
    "# Index of the waypoint the run starts from.\n",
    "start=40\n",
    "\n",
    "# Initial state built from that waypoint: position, heading from the local\n",
    "# difference vector, speed as 3x the difference norm, and zero delta and a.\n",
    "# NOTE(review): the factor 3 equals the stride used for `vs` below - confirm that\n",
    "# the two are meant to be coupled.\n",
    "x0,y0,gx0,gy0 = waypoints[start]\n",
    "v0 = torch.sqrt(gx0.pow(2)+gy0.pow(2))*3\n",
    "yaw0 = torch.atan2(gy0,gx0)\n",
    "delta0=torch.zeros(1)\n",
    "a0=torch.zeros(1)\n",
    "state0 = torch.cat(  [x0.view(1),y0.view(1),yaw0.view(1),v0.view(1),delta0,a0])\n",
    "# control0 is not referenced again in the visible cells.\n",
    "control0 = torch.Tensor([0.0001,0.0001,0.001])\n",
    "\n",
    "# Initial per-step variance weights passed to mpc.run as `_vars`.\n",
    "vars0 = torch.ones(len_horizon)*0.1\n",
    "_vars=vars0\n",
    "\n",
    "# Zero initial guess for the (len_horizon, 2) control sequence.\n",
    "_controls = torch.nn.Parameter(torch.zeros(len_horizon,2))\n",
    "\n",
    "# Per-step dt column appended to the controls inside GP_MPC.run.\n",
    "dt = torch.ones(len_horizon,1)\n",
    "#vs = torch.nn.Parameter(torch.ones(len_horizon)).data.clone()*3.0001\n",
    "# Waypoint offsets 0,3,...,3*len_horizon; GP_MPC.run indexes waypoints at vs[1:]+start.\n",
    "vs = torch.LongTensor(torch.arange(len_horizon+1))*3\n",
    "\n",
    "print(vs)\n",
    "real_path = []\n",
    "\n",
    "# GP_MPC.__init__ only stores its arguments, so one instance serves every step.\n",
    "# NOTE(review): GP_MPC and evaluate_ are defined in cells further down the\n",
    "# notebook; those cells must be executed before this one.\n",
    "mpc = GP_MPC(gp_propagator,evaluate_,len_horizon,waypoints)\n",
    "\n",
    "# Receding-horizon loop: optimise the horizon, apply the first control to the\n",
    "# simulated car, then warm-start the next step from the shifted solution.\n",
    "for T in range(TMAX):\n",
    "    print(\"T=\")\n",
    "    print(T)\n",
    "\n",
    "    if T>0:\n",
    "        # Continue from the waypoint index and warm start computed last step.\n",
    "        start = start_\n",
    "        _controls = controls_.data.clone()\n",
    "\n",
    "    controls_dt,path,vars_ = mpc.run(state0,_controls,vs,dt,start,_vars)\n",
    "    print(vars_)\n",
    "    # Optimised controls without the trailing dt column.\n",
    "    controls = controls_dt[:,:-1]\n",
    "    print(state0.size(),controls_dt[0].size())\n",
    "    # Apply only the first optimised control row (with its dt) to the simulated car.\n",
    "    state_real = real_car_model.forward(state0,controls_dt[0])\n",
    "    real_path.append(state_real)\n",
    "\n",
    "    # Warm start for the next step: drop the first step+1 rows, shift the rest\n",
    "    # forward and zero the tail. The original shifted the initial guess\n",
    "    # `_controls`; mpc.run optimises a private clone of it, so the optimised\n",
    "    # `controls` were never carried over and every step restarted from zeros.\n",
    "    controls_ = controls.data.clone()\n",
    "    controls_[:-(step+1)] = controls.data[step+1:]\n",
    "    controls_[-(step+1):]=0\n",
    "    # NOTE(review): search_ presumably returns the waypoint index to continue\n",
    "    # from; its definition is not visible in this notebook - TODO confirm.\n",
    "    start_ = search_(state_real,waypoints,start)\n",
    "\n",
    "    # The measured state and the latest variances seed the next optimisation.\n",
    "    state0 = state_real.data.clone()\n",
    "    _vars = vars_.data.clone()\n",
    "    \n",
    "    \n",
    "# Planned x/y coordinates from the last mpc.run call, as NumPy arrays.\n",
    "x_ = path.data[:, 0].numpy()\n",
    "y_ = path.data[:, 1].numpy()\n",
    "\n",
    "# Plan as large blue markers over the reference course as small red markers.\n",
    "plt.scatter(x_, y_, s=20, c='b')\n",
    "plt.scatter(cx, cy, s=1, c='r')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "# Scratch check on a random 40x40 matrix (unseeded, so values differ per run).\n",
    "a=torch.randn(40,40)\n",
    "# svd() returns (U, S, V); [1][0] is the first entry of S. The cell displays that\n",
    "# value together with row 1 of `a`.\n",
    "a.svd()[1][0],a[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 6])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7ff6e05fc278>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAWbElEQVR4nO3dfYxldX3H8c+3PLTsotJmxoeyLEMspSFWC5kYLYklAoYiGfjDP6TV6LjppklrsaW7BUm6oYmN3Wl8SDQ2E3G0kaxtUCNpqBVsSWNSzM5S8IGFSuguLEKZSVs1C8lC+u0f517mzNn7fM89v4fzfiWbuU9zzm/vnPM93/P9/c7vmLsLAJCfnwvdAADAbBDgASBTBHgAyBQBHgAyRYAHgEydGWKlc3NzvrCwEGLVAJCsI0eObLr7/KifDxLgFxYWtL6+HmLVAJAsMzs+zucp0QBApgjwAJApAjwAZIoADwCZIsADQKYI8ACQKQI8AGSKAA+gOZub0spK8bP7+PHHt15DrYJc6ASgZTY3pbU16eRJ6Y47tl7fv1964AHp3nuL93bulJaXpbm5YE3NCQEewOxUA/uBA9LBg0UQ71pakq68svjM/v0E+hoR4AHUb1BgLwftffu2fm5uFoG9G+il4vNrawT7CRHgAdSjG9S7QXn//v6BvZe5ue2BvrwcsvqJEOAB1KMbjKWtEswkAbkb6MvLKWf13fcwFAEewHS6mfvSUvG8G9TrCMTVrH5pqRhxQyY/EgI8gOmUM/dZZdfdQL+yQiY/BgI8gOmUyzFNrYtMfiRc6ARgMt0LlaQim24i0HYz+XvuKTL5tbXZrzNhtQR4MzvPzO42s8fM7KiZvb2O5QKIWLc0EyLILi9vjc4pXx2Lbeoq0Xxa0jfd/T1mdrakHTUtF0BsenWqNq3ciUtdvq+pA7yZvUbSOyR9UJLc/ZSkU9MuF0CkmuhUHQd1+b7qyOAvkrQhac3M3iLpiKSb3f1k+UNmtlfSXknavXt3DasFEESTnaqjYIRNX3XU4M+UdLmkz7n7ZZJOSrq1+iF3X3X3RXdfnJ+fr2G1ABoVolN1HOW6PCTVE+BPSDrh7t/tPL9bRcAHkJOQnaqjqNbl6XSdvkTj7s+Z2dNmdom7Py7pKkmPTt80AFGJrTTTT2x9BAHVNYrmw5Lu6oygeVJS5FsAgJGVJxFLIWCmciBqQC3j4N394U59/c3ufqO7/08dywUQgdhLM1WUal7BVAUABks1I6ZUQ4AH0EO5LFPXzJBNS/XAVCPmogFwutTKMr1QqiGDB9BDTtlvi0s1BHgAW1IbMTOKnA5WY6JEA2BLDqWZqm6pZm6udTNPksED2JJ7ttuycg0BHkCepZlecj+AVVCiSVH1NLP8fJTHQFWOpZleWjayhgw+Jd0s6+RJ6Y47itf27dt+2ikNf9y9S303Y5O2j3lG+7Qss21LqYYAH7tyIO5ulAcObJ8WtdfOOexxv4NCxhs7Bkj1YqZJteSAZu7e+EoXFxd9fX298fUmo1dQ7wb0ujLtQRl89SpG5Iu/dVLM7Ii7L476eTL4GJWz63KmUWeWVV1W+XH19JUgkK+WlCr6ynzbJsDHoryhzSqoj6p6+lo94GS8Q7ROS0oVfWV+gCPAh9av4zTkxlY9qJSDQOY7RCvkMJFYXTI/wDFMMrRywIz1fpLlKwGr971k+GV62jIkchSZD5skgw9hUDkmdtWMj4w+PZlnrRPJdDsmwIdQ3ZhS3qCqwSLzTqukteVq1UlketAjwDepu4MtLRXPc9iYyOjTwd+mv0z7IgjwTWrDDkZGH69Ms9RaZba9EuBnrV+9PVdk9PHKNEutVWbbKwF+1nKqt0+CjD48vvPRZZaEEeBnJcd6+yTI6MPjOx9dZmc5tQV4MztD0rqkZ9z9+rqWmyx2qt7I6JuXWVY6cxltk3Vm8DdLOirp1TUuM13sVL2R0TeD
q1Unl9E2WUuAN7Ndkt4t6WOS/qSOZSaLscbj6XUgzCiDCiajINW4jJKzujL4T0naL+lV/T5gZnsl7ZWk3bt317TaCLFjjadXdsl3OL2MglTjMjrjmXouGjO7XtLz7n5k0OfcfdXdF919cX5+ftrVxqs6VwvGx3w3k+t+V9LW/EGYTAbbXR2TjV0hacnMjkn6iqR3mtmXa1huOsobQnliLkym+h0yOdbo+K7qk8F3OXWJxt1vk3SbJJnZlZL+1N3fN+1yk0JJYbYYeTM6SjP1yeC7ZBz8NBjr3gxG3gzGiJnZyOC7rDXAu/sDkh6oc5lRI9CEQUa/Hdsh+iCDn0YGp3BJIqPfju1wthJOILij0yQYqRCXNo66oWO/OQl3tpLBT6LtGWNs2pjRt+H/GIuEz5AI8JNI+A/eCrnW6Ns29XQsEu5spUQzDkozach1HH35/0FZpnkJlv7I4MfBaXGaUs7oydrjkeD+T4AfBztYmobV6GMO+G2/YUxMEtz/CfDjSLgWh5LqjhpbZkbWHqcE938C/ChizvAwvuqOGlsJh6wdNaGTdRS5dNKht2Gdsk10rpXXwYyk8Uqso5UMfhScJrfLsBJOXRl+eTlk7WmIrZw3BAF+EO7O1E7DSjjDAn75effzvd4rL4ckIg2J/Z0I8IMkdrTGjIwb8MvPpf7vlZeTYAdeKyX2dyLAD5LY0RoNGRbwe203vd5LLFigJHRH/IjM3Rtf6eLioq+vrze+XgCoxcpKcSZ28GCjB2kzO+Lui6N+ngy+n0SO0AACSOTsnmGS/TA0EkA/icwFRAbfTyJHaADohwy+ihkjAYwq8gufCPBVlGYAjCryeEGJporSDIBRRR4vGCYJAIkYd5gkJZquyGtpACIVcewgwHdFXksDEKmIY8fUNXgzu0DS30p6nSSXtOrun552uY2LvJYGIFIRx46pa/Bm9gZJb3D3h8zsVZKOSLrR3R/t9zvU4AFgfI3X4N39WXd/qPP4Z5KOSjp/2uU2JuL6GQBMo9YavJktSLpM0nd7vLfXzNbNbH1jY6PO1U4n4voZgEREmijWNg7ezM6V9FVJH3H3n1bfd/dVSatSUaKpa71Ti7h+BiARkd47opYAb2ZnqQjud7n71+pYZmOYkxvAtCJNFKcu0ZiZSbpT0lF3/8T0TWpQpKdVABIT6eySddTgr5D0fknvNLOHO/+uq2G5s0f9HUDGpi7RuPt3JFkNbWlepKdVABIV2Y2C2n0la6SnVQASFVlVoJ2zSUZ2lAWQiciqAu3M4CM7ygLIRGRVgXZm8JEdZQFgFtqZwUd2lAWQkYiGX7crwEf0xQPIVEQl4HaVaCK9nBhARiIqAbcrwEf0xQPIVETTn7QrwEf0xQPArLWnBk/9HUBTIok37QnwEXV8AMhcJPGmPSUa6u8AmhJJvJn6nqyT4J6sADC+xu/JCgCIU/4BPpLODgAtE0HsyT/AR9LZAaBlIog9+XeyRtLZAaBlIog9dLICQCLoZC2LoAYGAKHkHeAjqIEBaLHASWbeNfgIamAAWizwDLZ5B3gmFwMQUuAkM+8ADwAhBU4y86zB07kKIBYB41EtAd7MrjWzx83sCTO7tY5lToXOVQCxCBiPpi7RmNkZkj4r6RpJJyQdNrN73P3RaZc9MTpXAcQiYDyqI4N/q6Qn3P1Jdz8l6SuSbqhhuZPr1r3m5gZ/jlIOgFkbNR7NQB0B/nxJT5een+i8to2Z7TWzdTNb39jYqGG1fYwTtKunTgR8ABlprJPV3VfdfdHdF+fn52e3onHqXcvL0sGDW6dO5d8l2AOoQ8BYUscwyWckXVB6vqvzWhjj1LuqQ5jKv1u9QGFzs3hteTnIqRaARAW82KmOAH9Y0sVmdpGKwP5eSb9Tw3InM8240/LvVg8Uga9IA5CogJ2stcwmaWbXSfqUpDMkfcHdPzbo80nOJlnO4CWyeQCNCzKbpLvf6+6/6u5vHBbcZ2bWda5yTzidswASkM+VrE1eTEDn
LIBxBIoL+cxF02Sdi85ZAOMI1IeXT4APOakPnbMABgnU0ZrHLftizpKrbYu5rQCi1s5b9sU8uVj1MmXq9QAakkeJJqXJxajXA+0UYP/OI8CndOcm6vVAOwXYv/MI8KkaNBqHbB7IS4BKQ9o1+Nxq2FxMBeQrwLTBaWfwOZc0hpVvyPABDJF2gE+pc3Vcg8o30vaA3+2wJdgD8aKTdUwpda5Oa5yrZwHEh07WMbS9RDFoNE7bvxsgRnSyjiHmi5uaNuhiKokOWiAGdLKOIef6+7QYXw/Ehxr8GNpUfx8X4+uB+FCDx0yUA/7KCsMtgRCowY+ImvLkBt2sROK7BWaFGvyIqClPbpzx9Xy3QNLSDPB0sNZnWMCnhAPUI8C+lGaJJsCpTmswfz0wGwGGdqeZwaM5XDEL1CNA5SGtW/ZRLgiL2w8CQTV6yz4zWzGzx8zse2b2dTM7b5rlDcXVq2FxxSwwuQD7x7Qlmvsk3ebuL5vZX0m6TdKfTd+sPuhcjQsjcIDRpXahk7t/q/T0QUnvma45Q3D1alwYgQOMLkCCWmcn64ck/V2/N81sr6S9krR79+4aV4toVAM+GT2wJUCCOjTAm9n9kl7f463b3f0bnc/cLullSXf1W467r0palYpO1olai7SQ0QNBDQ3w7n71oPfN7IOSrpd0lYcYkoN4DcvoCfhok9RmkzSzayXtl/Rb7v5CPU1CtuiURZul1skq6TOSfl7SfWYmSQ+6++9P3SrkiU5ZtFlqnazu/it1NQQtRKcs2iRAJ2uac9EgT9WpjCUunkI+AmzLBHjEo9ckclwti1ww2RhQQccscsFkY8AQTHiGFmt0sjGgccMmPANileBkY0BYDLVEKhIcBw+ExVBLpCK1cfBAdHrtRGT1iAHj4IEpjTLUEmhaoOG9ZPDIH3V6hBaodEiAR/6o0yO0QHejI8Cjfcjo0bRAd6OjBo/2YSw9mkYNHgiEkTeYNWrwQCC9Tp+p06NO1OCBiJDVo07U4IGIMJ4edQk4xTUZPDAqRt9gEgHLfQR4YFSMp8ckAtXfJQI8MDnq9BhFoPq7RA0emBx1egwT+BaTZPBAnajToyxwGY8AD9SJOj3KAtbfJQI8MFvU6dsrgr9zLTV4M7vFzNzM2FqBMur07RXB33nqDN7MLpD0LklPTd8coAXI6tshcHlGqieD/6Sk/ZK8hmUB+SOrz18kB+ypAryZ3SDpGXd/ZITP7jWzdTNb39jYmGa1QH6Wl6WDB7ePvqkMr9vYkA4fLn4icpEcsIeWaMzsfkmv7/HW7ZI+qqI8M5S7r0palaTFxUWyfaBsyOibQ4ekPXuks8+WTp2S7rxTuummME3FCCIoz0iSuU8Wa83s1yV9W9ILnZd2SfqxpLe6+3ODfndxcdHX19cnWi/QCqVT/A2f04UXSjte3NSy1rSmZb1wzpyOH5fm50M3FE0ysyPuvjjq5ycu0bj79939te6+4O4Lkk5IunxYcAcwglKd/tixInNf1ppWtF/LWtNZZ0nHjoVuJE4T+MrVKsbBA5FbWCjKMmsqTvfXtKyXXipej6UzDx2RXdhW21w0nUw+jsMWkJH5+aLm/sI5c1p99T69cM6c7ryzU56JpDMPKg62J09KBw4Er713MdkYkICbbpKOH5fuv7/4+UoHK6Nv4rG2Jt1xh7RzZzRnU5RogETMz/foVGX0TTwiGTlTRgYP5KSU0W9sFMF9x4ub+r2frGjHi5vas4dMvnbdsybp9AvYAiPAAzkZdfRNZKM9khZxPwglGiBTA0ffRDbaI2kRlma6yOCBTA0cfUPn7PQiLs10EeCBjPUdfVOd8KxSZjh0SLrwQumaa4qfhw6FaX/UIi7NdFGiATLXc/RNVanMUO6cXX6xmBphz545XX01UyNI2rq4bGmpeB5haaaLAA9g23DLY4c7nbMvFp2zkrR61j6deHhT8w9z1WxK/RcEeADb9Oucvfg7a9JfVAJbm6ZKSChz7yLAA9im2zm7Z8+c
Vs/ap5deKp6fe82ydK62B7aEstmJdQP7yZPFlapSMv/XiacLngbTBQPx29goxswvLAyovVcz+Bwz+pWV4iB24EAxDUHA/9u40wWTwQPoaaTO2SFTJSQb8MvtLo9zT+n/IIZJAqhTdXx9dShhzFfQlttWbneve+gmggweQH2qGX31Ks9eNfvQWX6vGnvEV6eOgwwewOxUs99qhi8NzvLryviry+mVrUtbbUs4ay8jgwfQnGqGLw3O8qXBNf1Bz7vLWl4+/cyh/DzhGvswBHgAYQ0r65QfDwrU1edS7yBe/dnroJMJhkkCSMekGXwmmfm4wyQJ8ACQiHEDPJ2sAJApAjwAZIoADwCZIsADQKYI8ACQKQI8AGSKAA8AmQoyDt7MNiQdb3zFvc1JinBqu55oa/1SaadEW2chlXZKRVt3uvvId8YNEuBjYmbr41w4EBJtrV8q7ZRo6yyk0k5psrZSogGATBHgASBTBHhpNXQDxkBb65dKOyXaOguptFOaoK2tr8EDQK7I4AEgUwR4AMhUqwO8mV1rZo+b2RNmdmvo9vRjZheY2b+Y2aNm9kMzuzl0mwYxszPM7N/N7B9Ct2UQMzvPzO42s8fM7KiZvT10m/oxsz/u/O1/YGaHzOwXQrepy8y+YGbPm9kPSq/9kpndZ2Y/6vz8xZBt7LSpVztXOn//75nZ183svJBt7OrV1tJ7t5iZm9nQu5i0NsCb2RmSPivptyVdKukmM7s0bKv6elnSLe5+qaS3SfqDiNsqSTdLOhq6ESP4tKRvuvuvSXqLIm2zmZ0v6Y8kLbr7mySdIem9YVu1zRclXVt57VZJ33b3iyV9u/M8tC/q9HbeJ+lN7v5mSf8h6bamG9XHF3V6W2VmF0h6l6SnRllIawO8pLdKesLdn3T3U5K+IumGwG3qyd2fdfeHOo9/piIQnR+2Vb2Z2S5J75b0+dBtGcTMXiPpHZLulCR3P+Xu/xu2VQOdKekcMztT0g5JPw7cnle4+79K+u/KyzdI+lLn8Zck3dhoo3ro1U53/5a7v9x5+qCkXY03rIc+36kkfVLSfkkjjY5pc4A/X9LTpecnFGnQLDOzBUmXSfpu2Jb09SkVG+D/hW7IEBdJ2pC01iknfd7MdoZuVC/u/oykv1aRtT0r6Sfu/q2wrRrqde7+bOfxc5JeF7IxI/qQpH8M3Yh+zOwGSc+4+yOj/k6bA3xyzOxcSV+V9BF3/2no9lSZ2fWSnnf3I6HbMoIzJV0u6XPufpmkk4qjjHCaTv36BhUHpV+WtNPM3he2VaPzYix21OOxzex2FaXQu0K3pRcz2yHpo5L+fJzfa3OAf0bSBaXnuzqvRcnMzlIR3O9y96+Fbk8fV0haMrNjKkpe7zSzL4dtUl8nJJ1w9+6Z0N0qAn6Mrpb0n+6+4e4vSfqapN8M3KZh/svM3iBJnZ/PB25PX2b2QUnXS/pdj/fCoDeqOMA/0tm/dkl6yMxeP+iX2hzgD0u62MwuMrOzVXRa3RO4TT2ZmamoFR9190+Ebk8/7n6bu+9y9wUV3+c/u3uUmaa7PyfpaTO7pPPSVZIeDdikQZ6S9DYz29HZFq5SpB3CJfdI+kDn8QckfSNgW/oys2tVlBSX3P2F0O3px92/7+6vdfeFzv51QtLlne24r9YG+E7Hyh9K+icVO8vfu/sPw7aqryskvV9FRvxw5991oRuVgQ9LusvMvifpNyT9ZeD29NQ5y7hb0kOSvq9iv43mEnszOyTp3yRdYmYnzGyPpI9LusbMfqTiDOTjIdso9W3nZyS9StJ9nf3qb4I2sqNPW8dfTrxnJACAabQ2gweA3BHgASBTBHgAyBQBHgAyRYAHgEwR4AEgUwR4AMjU/wPGZ9hRkpIw5QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Stack the states appended to `real_path` by the MPC loop into rows of 6 values.\n",
    "real_path_ = torch.cat(real_path,dim=0).view(-1,6)\n",
    "\n",
    "print(real_path_.size())\n",
    "# Columns 0 and 1 are plotted as x and y.\n",
    "x_ = real_path_[:,0].data.numpy()\n",
    "y_ = real_path_[:,1].data.numpy()\n",
    "    \n",
    "# Simulated states in blue over the reference course (cx, cy) in red.\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-plot of the last planned `path`; same statements as the end of the MPC cell.\n",
    "# NOTE(review): this overwrites the x_/y_ set by the real-path plot above, and the\n",
    "# next cell reads whichever values were assigned last.\n",
    "x_ = path[:,0].data.numpy()\n",
    "y_ = path[:,1].data.numpy()\n",
    "    \n",
    "# Planned states in blue over the reference course (cx, cy) in red.\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay: current x_/y_ in blue, the reference course in red, and a third set of\n",
    "# points (wx, wy) in black.\n",
    "plt.scatter(x_,y_,c='b',s=40)\n",
    "plt.scatter(cx,cy,c='r',s=1)\n",
    "# Course points at the offsets vs+start ...\n",
    "wx,wy = cx[vs+start].data.cpu().numpy(),cy[vs+start].data.cpu().numpy()\n",
    "# NOTE(review): ... but this line immediately overwrites them from `s0`, which is\n",
    "# not defined anywhere in the visible notebook. On a fresh kernel it raises\n",
    "# NameError. Decide which of the two assignments is intended and remove the other.\n",
    "wx,wy=s0[:,0],s0[:,1]\n",
    "plt.scatter(wx,wy,c='k',s=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ql = 1.0\n",
    "qc = 1.0\n",
    "cv = 0.5\n",
    "\n",
    "\n",
    "def evaluate_(path, refs, vs, variances):\n",
    "    \"\"\"Score a planned path against its reference rows.\n",
    "\n",
    "    The result is error_deviation_parallel_(path, refs, ql, qc) plus cv times\n",
    "    the variance-weighted sum of squared offsets taken over the first two\n",
    "    columns of path and refs. `vs` is accepted but not used.\n",
    "    \"\"\"\n",
    "    tracking_cost = error_deviation_parallel_(path, refs, ql, qc)\n",
    "    offsets = path[:, :2] - refs[:, :2]\n",
    "    squared_dist = offsets.pow(2).sum(dim=1)\n",
    "    spread_cost = (variances * squared_dist).sum()\n",
    "    return tracking_cost + cv * spread_cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate=0.01\n",
    "dim_s=2\n",
    "\n",
    "class GP_MPC:\n",
    "    \"\"\"Gradient-based MPC over a fixed horizon using a Gaussian state propagator.\n",
    "\n",
    "    ``gp_propagator`` must provide ``initialize(state, control)`` and\n",
    "    ``forward(state, control, sigma, flag)``, each returning a ``(state, sigma)``\n",
    "    pair. ``evaluate(path, refs, vs, variances)`` must return a scalar loss.\n",
    "    The number of optimisation epochs is read from the global ``EMAX``; the\n",
    "    Adam step size comes from ``learning_rate``.\n",
    "    \"\"\"\n",
    "    def __init__(self,gp_propagator,evaluate,len_horizon,waypoints):\n",
    "        self.waypoints = waypoints\n",
    "        self.len_horizon = len_horizon\n",
    "        self.gp_propagator = gp_propagator\n",
    "        self.evaluate = evaluate\n",
    "\n",
    "    def _rollout(self,state,sigma,controls_dt,track_variance):\n",
    "        \"\"\"Propagate ``state`` through the horizon and stack the visited states.\n",
    "\n",
    "        With ``track_variance`` set, sigma is propagated too and the largest\n",
    "        singular value of its leading ``dim_s x dim_s`` block is collected per\n",
    "        step. Otherwise the given sigma is reused unchanged at every step and\n",
    "        ``None`` is returned in place of the variances.\n",
    "        \"\"\"\n",
    "        path = []\n",
    "        variances = [] if track_variance else None\n",
    "        for t in range(self.len_horizon):\n",
    "            if track_variance:\n",
    "                state,sigma = self.gp_propagator.forward(state,controls_dt[t],sigma,True)\n",
    "                sigma_xy = sigma[:dim_s,:dim_s]\n",
    "                # The original decomposed the full sigma and left sigma_xy unused,\n",
    "                # contradicting its own comment; decompose the block instead.\n",
    "                _,eigs,_ = sigma_xy.svd()\n",
    "                variances.append(eigs[0].view(1))   # largest singular value of sigma_xy\n",
    "            else:\n",
    "                state,_ = self.gp_propagator.forward(state,controls_dt[t],sigma,False)\n",
    "            path.append(state.view(1,-1))\n",
    "        path = torch.cat(path,dim=0)\n",
    "        if track_variance:\n",
    "            variances = torch.cat(variances,dim=0)\n",
    "        return path,variances\n",
    "\n",
    "    def run(self,state_init,controls_init,vs,dt,start,_vars):\n",
    "        \"\"\"Optimise the control sequence with Adam for ``EMAX`` epochs.\n",
    "\n",
    "        Args:\n",
    "            state_init: initial state tensor; a detached copy is used.\n",
    "            controls_init: initial control guess; a private copy is optimised,\n",
    "                so the caller's tensor is not modified.\n",
    "            vs: integer offsets; ``self.waypoints[vs[1:]+start]`` are the references.\n",
    "            dt: per-step dt column, or a length-1 value broadcast to the horizon.\n",
    "            start: waypoint index added to ``vs``.\n",
    "            _vars: variances forwarded to ``evaluate`` on every epoch.\n",
    "\n",
    "        Returns:\n",
    "            ``(controls_dt, path, vars_)``: the optimised controls with dt\n",
    "            appended, and the path and per-step variances of the last epoch's\n",
    "            rollout (computed before that epoch's optimiser step).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``EMAX`` is smaller than 1, since no rollout would run.\n",
    "        \"\"\"\n",
    "        if EMAX < 1:\n",
    "            raise ValueError(\"EMAX must be at least 1\")\n",
    "        start = torch.LongTensor([start])\n",
    "        state_init = state_init.data.clone()\n",
    "        controls = torch.nn.Parameter(controls_init.data.clone())\n",
    "        vs = torch.LongTensor(vs.data.clone())\n",
    "        if len(dt)==1:\n",
    "            dt = torch.ones(self.len_horizon).view(-1,1)*dt\n",
    "        opt =  torch.optim.Adam([controls],lr=learning_rate)\n",
    "        # Use the instance's waypoints and horizon rather than same-named globals.\n",
    "        refs = self.waypoints[vs[1:]+start]\n",
    "        _,sigma_init = self.gp_propagator.initialize(state_init,controls[0])\n",
    "        vars_ = None\n",
    "        for epoch in range(EMAX):\n",
    "            controls_dt = torch.cat([controls,dt],dim=1)\n",
    "            # Variances are only tracked on the final epoch.\n",
    "            path,step_vars = self._rollout(state_init,sigma_init,controls_dt,epoch == EMAX-1)\n",
    "            if step_vars is not None:\n",
    "                vars_ = step_vars\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss = self.evaluate(path,refs,vs,_vars)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            if epoch % 500 ==0:\n",
    "                print(loss.data.numpy())\n",
    "        controls_dt = torch.cat([controls,dt],dim=1)\n",
    "        return controls_dt,path,vars_.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
